import os
import time
from tqdm import tqdm
from contextlib import suppress
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Sampler
import numpy as np
from .zeroshot_classification import accuracy

from sklearn.metrics import classification_report, balanced_accuracy_score

def assign_learning_rate(param_group, new_lr):
    """Store ``new_lr`` under the ``"lr"`` key of an optimizer param group, in place."""
    lr_key = "lr"
    param_group[lr_key] = new_lr

def _warmup_lr(base_lr, warmup_length, step):
    """Linear warmup: the rate at ``step`` is ``base_lr * (step + 1) / warmup_length``."""
    scaled = base_lr * (step + 1)
    return scaled / warmup_length

def cosine_lr(optimizer, base_lrs, warmup_length, steps):
    """Build a per-step learning-rate setter: linear warmup, then cosine decay.

    Args:
        optimizer: optimizer whose ``param_groups`` receive the computed rates.
        base_lrs: peak learning rate; either a single value shared by every param
            group or a ``list`` with exactly one entry per param group.
        warmup_length: number of initial steps that ramp linearly toward the peak.
        steps: total step count; the cosine factor reaches 0 at this step.

    Returns:
        A callable ``f(step)`` that writes the rate for ``step`` into every param
        group of ``optimizer`` and returns None.
    """
    if isinstance(base_lrs, list):
        peak_lrs = base_lrs
    else:
        peak_lrs = [base_lrs] * len(optimizer.param_groups)
    assert len(peak_lrs) == len(optimizer.param_groups)

    def _lr_adjuster(step):
        for group, peak in zip(optimizer.param_groups, peak_lrs):
            if step < warmup_length:
                new_lr = _warmup_lr(peak, warmup_length, step)
            else:
                progress = step - warmup_length
                span = steps - warmup_length
                new_lr = 0.5 * (1 + np.cos(np.pi * progress / span)) * peak
            assign_learning_rate(group, new_lr)

    return _lr_adjuster


class Featurizer(torch.nn.Module):
    """Module whose forward pass returns ``model.encode_image(input)``.

    When ``normalize`` is true, the embeddings are L2-normalized along the last
    dimension before being returned.
    """

    def __init__(self, model, normalize=True):
        super().__init__()
        self.model = model
        self.normalize = normalize

    def forward(self, input):
        embeddings = self.model.encode_image(input)
        return F.normalize(embeddings, dim=-1) if self.normalize else embeddings

class FeatureDataset(Dataset):
    """Map-style dataset that pairs ``features[i]`` with ``targets[i]``.

    The length is taken from ``features``.
    """

    def __init__(self, features, targets):
        self.features = features
        self.targets = targets

    def __len__(self):
        return len(self.features)

    def __getitem__(self, i):
        feature = self.features[i]
        label = self.targets[i]
        return feature, label


def train(dataloader, input_shape, output_shape, weight_decay, lr, epochs, amp, device, seed):
    """Fit a linear classifier with AdamW and a cosine learning-rate schedule.

    Args:
        dataloader: yields ``(x, y)`` batches. ``x`` is fed to
            ``torch.nn.Linear(input_shape, output_shape)`` and ``y`` to
            ``CrossEntropyLoss``.
        input_shape: ``in_features`` of the linear layer.
        output_shape: ``out_features`` of the linear layer (number of classes).
        weight_decay: AdamW weight decay.
        lr: peak learning rate, decayed by ``cosine_lr`` with no warmup.
        epochs: number of passes over ``dataloader``.
        amp: enable ``torch.autocast`` for the forward pass and loss.
        device: device type passed to ``torch.autocast``. The model and batches are
            always moved with ``.cuda()``, so this is expected to be ``'cuda'``.
        seed: passed to ``torch.manual_seed`` before the layer is created.

    Returns:
        The trained layer, wrapped in ``torch.nn.DataParallel`` over all visible GPUs.
    """
    torch.manual_seed(seed)
    model = torch.nn.Linear(input_shape, output_shape)
    devices = list(range(torch.cuda.device_count()))
    model = model.cuda()
    model = torch.nn.DataParallel(model, device_ids=devices)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay,
    )
    criterion = torch.nn.CrossEntropyLoss()

    # len(dataloader) is the number of batches per epoch.
    len_loader = len(dataloader)
    scheduler = cosine_lr(optimizer, lr, 0., epochs * len_loader)

    for epoch in range(epochs):
        end = time.time()
        for i, (x, y) in enumerate(dataloader):
            x, y = x.cuda(), y.cuda()
            step = i + epoch * len_loader
            data_time = time.time() - end
            scheduler(step)

            optimizer.zero_grad()
            # NOTE(review): autocast is used without a GradScaler; confirm fp16
            # gradient underflow is not a concern for this linear probe.
            with torch.autocast(device, enabled=amp):
                pred = model(x)
                loss = criterion(pred, y)

            loss.backward()
            optimizer.step()

            batch_time = time.time() - end
            end = time.time()

            if (i % 20) == 1:
                num_samples = i * len(x)
                try:
                    # Compare samples with samples: len(dataloader) counts batches,
                    # so the epoch size in samples comes from the dataset instead.
                    samples_per_epoch = len(dataloader.dataset)
                    percent_complete = 100.0 * i / len_loader
                    progress_message = f"[{num_samples}/{samples_per_epoch} ({percent_complete:.0f}%)]"
                except (TypeError, AttributeError):
                    # The dataset has no length (e.g. an iterable-style dataset).
                    progress_message = f"[{num_samples} samples]"
                print(
                    f"Train Epoch: {epoch} {progress_message}\t"
                    f"Loss: {loss.item():.6f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}\t"
                    f"LR {optimizer.param_groups[0]['lr']:.5f}"
                )
    return model


def infer(model, dataloader, amp, device):
    """Run ``model`` over ``dataloader`` without gradients and collect results on CPU.

    Each ``(x, y)`` batch is moved to ``device``. The forward pass runs under
    ``torch.autocast(device, enabled=amp)``.

    Returns:
        ``(logits, target)``: per-batch model outputs and labels, each concatenated
        along dim 0.
    """
    logit_batches = []
    label_batches = []
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            with torch.autocast(device, enabled=amp):
                outputs = model(inputs)

            logit_batches.append(outputs.cpu())
            label_batches.append(labels.cpu())

    return torch.cat(logit_batches), torch.cat(label_batches)


def find_peak(wd_list, idxs, train_loader, val_loader, input_shape, output_shape, lr, epochs, amp, device, verbose, seed):
    """Train one probe per candidate weight decay and return the best candidate index.

    For each index in ``idxs``, a probe is trained with weight decay ``wd_list[idx]``
    on ``train_loader`` and scored by top-1 accuracy on ``val_loader``.

    Returns:
        The member of ``idxs`` with the highest validation accuracy. On ties, the
        first such index in ``idxs`` order wins.

    Raises:
        ValueError: if ``idxs`` is empty.
    """
    if not idxs:
        raise ValueError("find_peak needs at least one candidate index")

    # Start below any possible accuracy so the result is always one of `idxs`,
    # even when every probe scores 0.
    best_wd_idx, max_acc = None, float("-inf")
    # dict.fromkeys de-duplicates while keeping order: at the ends of wd_list the
    # caller's [left, peak, right] triple can repeat an index, which would
    # otherwise retrain an identical probe.
    for idx in dict.fromkeys(idxs):
        weight_decay = wd_list[idx]
        model = train(train_loader, input_shape, output_shape, weight_decay, lr, epochs, amp, device, seed)
        logits, target = infer(model, val_loader, amp, device)
        acc1, = accuracy(logits.float(), target.float(), topk=(1,))
        if verbose:
            print(f"Valid accuracy with weight_decay {weight_decay}: {acc1}")
        if acc1 > max_acc:
            best_wd_idx, max_acc = idx, acc1
    return best_wd_idx


def _split_cache_paths(feature_dir, save_str):
    """Return the (features, targets) file paths holding the cached features of one split."""
    return (
        os.path.join(feature_dir, f'features{save_str}.pt'),
        os.path.join(feature_dir, f'targets{save_str}.pt'),
    )


def _cache_split_features(featurizer, loader, feature_dir, save_str, device, amp):
    """Featurize every batch of ``loader`` and save the result to ``_split_cache_paths``.

    Features are flushed to temporary ``*{save_str}_cache_<k>.pt`` chunk files every
    100 batches. At the end the chunks are merged into the final files and deleted.

    Raises:
        ValueError: if ``loader`` yields no batches, so there is nothing to save.
    """
    def chunk_paths(k):
        return (
            os.path.join(feature_dir, f'features{save_str}_cache_{k}.pt'),
            os.path.join(feature_dir, f'targets{save_str}_cache_{k}.pt'),
        )

    def flush(feature_list, target_list, k):
        features_path, targets_path = chunk_paths(k)
        torch.save(torch.cat(feature_list), features_path)
        torch.save(torch.cat(target_list), targets_path)

    features, targets = [], []
    num_batches_tracked = 0
    num_cached = 0
    with torch.no_grad():
        for images, target in tqdm(loader):
            images = images.to(device)

            with torch.autocast(device, enabled=amp):
                feature = featurizer(images)

            features.append(feature.cpu())
            targets.append(target)

            num_batches_tracked += 1
            if num_batches_tracked % 100 == 0:
                flush(features, targets, num_cached)
                num_cached += 1
                features, targets = [], []

    if features:
        flush(features, targets, num_cached)
        num_cached += 1

    if num_cached == 0:
        raise ValueError(f"dataloader for split '{save_str}' yielded no batches")

    # Load every chunk first, then concatenate once. Concatenating pairwise inside
    # the loop re-copies the growing tensor each time, which is quadratic in the
    # number of chunks.
    feature_chunks, target_chunks = [], []
    for k in range(num_cached):
        features_path, targets_path = chunk_paths(k)
        feature_chunks.append(torch.load(features_path))
        target_chunks.append(torch.load(targets_path))

    features_path, targets_path = _split_cache_paths(feature_dir, save_str)
    torch.save(torch.cat(feature_chunks), features_path)
    torch.save(torch.cat(target_chunks), targets_path)

    # Delete the temporary chunks only once the merged files are written.
    for k in range(num_cached):
        for path in chunk_paths(k):
            os.remove(path)


def _select_fewshot_indices(targets, fewshot_k, seed):
    """Pick up to ``fewshot_k`` random sample indices per class (all indices if ``fewshot_k < 0``).

    The permutation comes from a private generator seeded with ``seed``. The chosen
    subset therefore does not depend on the global RNG state, which featurization
    (e.g. a shuffling dataloader) may or may not have advanced depending on whether
    the features were already cached.

    Returns:
        The selected indices, or None when ``fewshot_k > 0`` and some class has
        fewer than ``fewshot_k`` samples.
    """
    generator = torch.Generator().manual_seed(seed)
    perm = torch.randperm(len(targets), generator=generator).tolist()
    labels = targets.tolist()

    counts = {}
    idxs = []
    for p in perm:
        label = labels[p]
        counts.setdefault(label, 0)
        if fewshot_k < 0 or counts[label] < fewshot_k:
            counts[label] += 1
            idxs.append(p)

    if fewshot_k > 0 and any(n != fewshot_k for n in counts.values()):
        return None
    return idxs


def _feature_loader(features, targets, batch_size, num_workers, shuffle):
    """Wrap in-memory features/targets in a pinned-memory DataLoader."""
    return DataLoader(
        FeatureDataset(features, targets), batch_size=batch_size,
        shuffle=shuffle, num_workers=num_workers,
        pin_memory=True,
    )


def _sweep_weight_decay(train_loader, val_loader, input_shape, output_shape, lr, epochs, amp, device, verbose, seed):
    """Choose a weight decay by coarse-to-fine search on validation accuracy.

    This is an OpenAI-like sweep (https://arxiv.org/pdf/2103.00020.pdf, A.3), using
    FCNNs trained with AdamW instead of scikit-learn's L-BFGS.
    """
    # 97 log-spaced values over [1e-6, 1e2]. The 7-point coarse grid lands on every
    # 16th entry. The match uses isclose rather than exact float equality.
    wd_list = np.logspace(-6, 2, num=97).tolist()
    wd_list_init = np.logspace(-6, 2, num=7).tolist()
    wd_init_idx = [
        i for i, val in enumerate(wd_list)
        if np.isclose(val, wd_list_init, rtol=1e-9, atol=0.0).any()
    ]
    peak_idx = find_peak(wd_list, wd_init_idx, train_loader, val_loader, input_shape, output_shape, lr, epochs, amp, device, verbose, seed)

    # Refine around the current peak with a halving neighbourhood: 8, 4, 2, 1.
    step_span = 8
    while step_span > 0:
        left, right = max(peak_idx - step_span, 0), min(peak_idx + step_span, len(wd_list) - 1)
        peak_idx = find_peak(wd_list, [left, peak_idx, right], train_loader, val_loader, input_shape, output_shape, lr, epochs, amp, device, verbose, seed)
        step_span //= 2
    return wd_list[peak_idx]


def evaluate(model, train_dataloader, dataloader, fewshot_k, batch_size, num_workers, lr, epochs,
             model_id, seed, feature_root, device, val_dataloader=None, normalize=True, amp=True, verbose=False):
    """Linear-probe evaluation of ``model``'s image features.

    Steps:
      1. Featurize the train, (optional) val and test loaders with ``Featurizer`` and
         cache them under ``feature_root/model_id``. Splits already cached are reused.
      2. Build the probe training set from ``fewshot_k`` random samples per class, or
         from all samples when ``fewshot_k < 0``.
      3. If ``val_dataloader`` is given, pick the weight decay with a validation sweep.
         Otherwise use weight decay 0.
      4. Train the final probe and score it on the test features.

    Args:
        model: provides ``encode_image``, which is called through ``Featurizer``.
        train_dataloader, dataloader, val_dataloader: image loaders for the train,
            test and optional validation splits.
        fewshot_k: samples per class for the probe (negative means use all).
        batch_size, num_workers: settings for the feature DataLoaders.
        lr, epochs, seed: forwarded to ``train``.
        model_id, feature_root: the feature cache lives in ``feature_root/model_id``.
        device: must be ``'cuda'``.
        normalize: L2-normalize features before caching.
        amp: enable ``torch.autocast`` while featurizing, training and inferring.
        verbose: print per-sweep accuracies and the test classification report.

    Returns:
        A dict with ``lp_acc1``, ``lp_acc5`` (NaN when the largest test label is below
        5), ``lp_mean_per_class_recall``, ``weight_decay``, ``epochs``, ``seed``,
        ``fewshot_k`` and ``normalized``. Returns None when some class has fewer than
        ``fewshot_k`` training samples.
    """
    assert device == 'cuda'  # need to use cuda for this else too slow

    # 1. Featurize and cache. makedirs also handles a missing feature_root and
    #    model ids that contain a path separator.
    feature_dir = os.path.join(feature_root, model_id)
    os.makedirs(feature_dir, exist_ok=True)

    featurizer = Featurizer(model, normalize).cuda()
    splits = [("_train", train_dataloader), ("_val", val_dataloader), ("_test", dataloader)]
    # Check each split on its own. A run interrupted after caching the train split,
    # or an earlier run without a validation loader, must not leave val/test missing.
    missing = [
        (save_str, loader) for save_str, loader in splits
        if loader is not None
        and not all(os.path.exists(path) for path in _split_cache_paths(feature_dir, save_str))
    ]
    if missing:
        devices = list(range(torch.cuda.device_count()))
        featurizer = torch.nn.DataParallel(featurizer, device_ids=devices)
        for save_str, loader in missing:
            _cache_split_features(featurizer, loader, feature_dir, save_str, device, amp)

    features_path, targets_path = _split_cache_paths(feature_dir, "_train")
    features = torch.load(features_path)
    targets = torch.load(targets_path)

    # 2. Select k features per class (all features when k < 0).
    idxs = _select_fewshot_indices(targets, fewshot_k, seed)
    if idxs is None:
        print('insufficient data for this eval')
        return

    train_features = features[idxs]
    train_labels = targets[idxs]
    feature_train_loader = _feature_loader(train_features, train_labels, batch_size, num_workers, shuffle=True)

    # Evaluation order does not affect the metrics, so the test/val loaders are not shuffled.
    features_path, targets_path = _split_cache_paths(feature_dir, "_test")
    feature_test_loader = _feature_loader(
        torch.load(features_path), torch.load(targets_path), batch_size, num_workers, shuffle=False,
    )

    input_shape, output_shape = features[0].shape[0], targets.max().item() + 1

    # 3. Choose the weight decay.
    if val_dataloader is not None:
        features_path, targets_path = _split_cache_paths(feature_dir, "_val")
        features_val = torch.load(features_path)
        targets_val = torch.load(targets_path)
        feature_val_loader = _feature_loader(features_val, targets_val, batch_size, num_workers, shuffle=False)
        best_wd = _sweep_weight_decay(
            feature_train_loader, feature_val_loader, input_shape, output_shape,
            lr, epochs, amp, device, verbose, seed,
        )
        if fewshot_k < 0:
            # Full training: retrain on the full training set (train + val).
            train_loader = _feature_loader(
                torch.cat((train_features, features_val)),
                torch.cat((train_labels, targets_val)),
                batch_size, num_workers, shuffle=True,
            )
        else:
            # Few-shot: adding the validation set would train on more data than intended.
            train_loader = feature_train_loader
    else:
        best_wd = 0
        train_loader = feature_train_loader

    # 4. Train the final probe and measure accuracy on the test features.
    final_model = train(train_loader, input_shape, output_shape, best_wd, lr, epochs, amp, device, seed)
    logits, target = infer(final_model, feature_test_loader, amp, device)
    pred = logits.argmax(axis=1)

    if target.max() >= 5:
        acc1, acc5 = accuracy(logits.float(), target.float(), topk=(1, 5))
    else:
        acc1, = accuracy(logits.float(), target.float(), topk=(1,))
        acc5 = float("nan")
    mean_per_class_recall = balanced_accuracy_score(target, pred)

    if verbose:
        print(classification_report(target, pred, digits=3))
        print(f"Test acc1: {acc1} with weight_decay: {best_wd}")
    return {"lp_acc1": acc1, "lp_acc5": acc5, "lp_mean_per_class_recall": mean_per_class_recall,
            "weight_decay": best_wd, 'epochs': epochs, 'seed': seed, 'fewshot_k': fewshot_k, 'normalized': normalize}
